{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.6.1\n",
      "\n",
      "torch 1.2.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient Clipping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Certain types of deep neural networks, especially simple ones with a relatively large number of layers and no other form of regularization, can suffer from the exploding gradient problem. In this scenario, large loss gradients accumulate during backpropagation, which eventually results in very large weight updates during training. As a consequence, the updates become unstable and fluctuate strongly, which often causes severe problems during training. The problem is particularly pronounced for unbounded activation functions such as ReLU.\n",
    "\n",
    "One common, classic technique for avoiding exploding gradient problems is gradient clipping. Here, we simply set gradient values above or below a certain threshold to a user-specified max or min value, respectively. In PyTorch, there are several ways of performing gradient clipping.\n",
    "\n",
    "**1 - Basic Clipping**\n",
    "\n",
    "The simplest approach to gradient clipping in PyTorch is to use the [`torch.nn.utils.clip_grad_value_`](https://pytorch.org/docs/stable/nn.html?highlight=clip#torch.nn.utils.clip_grad_value_) function. For example, if we have instantiated a PyTorch model from a model class based on `torch.nn.Module` (as usual), we can add the following line of code between the `backward` call and the optimizer step in order to clip the gradients to the [-1, 1] range:\n",
    "\n",
    "```python\n",
    "torch.nn.utils.clip_grad_value_(parameters=model.parameters(), \n",
    "                                clip_value=1.)\n",
    "\n",
    "```\n",
    "\n",
    "However, notice that via this approach, we can only specify a single clip value, which will be used for both the upper and lower bound such that gradients will be clipped to the range [-`clip_value`, `clip_value`].\n",
    "\n",
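    "As a minimal, self-contained sketch of this behavior (the `toy_model`, its inputs, and the loss below are hypothetical stand-ins chosen for illustration, not part of the experiment in this notebook), we can check that no gradient element exceeds the clip value in magnitude after the call:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical toy model and data, used only for illustration\n",
    "torch.manual_seed(1)\n",
    "toy_model = torch.nn.Linear(4, 2)\n",
    "inputs = torch.randn(8, 4) * 100.  # large inputs tend to give large gradients\n",
    "targets = torch.randn(8, 2)\n",
    "\n",
    "loss = torch.nn.functional.mse_loss(toy_model(inputs), targets)\n",
    "loss.backward()\n",
    "\n",
    "# Largest absolute gradient element before clipping\n",
    "max_before = max(p.grad.abs().max().item() for p in toy_model.parameters())\n",
    "\n",
    "# Clip every gradient element to [-1, 1], in place\n",
    "torch.nn.utils.clip_grad_value_(toy_model.parameters(), clip_value=1.)\n",
    "\n",
    "# Largest absolute gradient element after clipping\n",
    "max_after = max(p.grad.abs().max().item() for p in toy_model.parameters())\n",
    "print(max_before, max_after)\n",
    "```\n",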
    "\n",
    "**2 - Custom Lower and Upper Bounds**\n",
    "\n",
    "If we want to clip the gradients to an asymmetric interval around zero, say [-0.1, 1.0], we can take a different approach by defining a backward hook:\n",
    "\n",
    "```python\n",
    "for param in model.parameters():\n",
    "    param.register_hook(lambda gradient: torch.clamp(gradient, -0.1, 1.0))\n",
    "```\n",
    "\n",
    "This backward hook only needs to be defined once after instantiating the model. Then, each time the `backward` method is called, the hook clips the gradients before the `optimizer.step()` method uses them to update the parameters.\n",
    "\n",
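    "The following sketch illustrates the hook-based approach on a small example. It assumes a hypothetical `toy_model` and an arbitrary loss, both introduced only for illustration; after one backward pass, the stored gradients should lie within the asymmetric interval:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(1)\n",
    "toy_model = torch.nn.Linear(4, 2)  # hypothetical stand-in for `model`\n",
    "\n",
    "# Register the hook once; it then runs on every backward pass\n",
    "for param in toy_model.parameters():\n",
    "    param.register_hook(lambda gradient: torch.clamp(gradient, -0.1, 1.0))\n",
    "\n",
    "inputs = torch.randn(8, 4) * 100.  # large inputs tend to give large gradients\n",
    "loss = toy_model(inputs).pow(2).mean()  # arbitrary loss for illustration\n",
    "loss.backward()\n",
    "\n",
    "# The values stored in .grad are the ones returned by the hook\n",
    "grad_min = min(p.grad.min().item() for p in toy_model.parameters())\n",
    "grad_max = max(p.grad.max().item() for p in toy_model.parameters())\n",
    "print(grad_min, grad_max)\n",
    "```\n",
    "\n",
    "Note that no explicit clipping call is needed in the training loop with this approach, since the hook is applied as part of `backward`.\n",
    "\n",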
    "**3 - Norm-clipping**\n",
    "\n",
    "Lastly, there's a third clipping option, [`torch.nn.utils.clip_grad_norm_`](https://pytorch.org/docs/stable/nn.html?highlight=clip#torch.nn.utils.clip_grad_norm_), which clips the gradients using a vector norm as follows:\n",
    "\n",
    "\n",
    "> `torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)`\n",
    "\n",
    ">Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.\n"
   ]
  },
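  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of norm-clipping (again using a hypothetical `toy_model` and an arbitrary loss that are not part of this notebook's experiment), we can recompute the norm over all gradients after the call and check that it does not exceed `max_norm`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(1)\n",
    "toy_model = torch.nn.Linear(4, 2)  # hypothetical stand-in for `model`\n",
    "inputs = torch.randn(8, 4) * 100.  # large inputs tend to give large gradients\n",
    "loss = toy_model(inputs).pow(2).mean()  # arbitrary loss for illustration\n",
    "loss.backward()\n",
    "\n",
    "# The return value is the total norm measured before clipping\n",
    "norm_before = float(torch.nn.utils.clip_grad_norm_(\n",
    "    toy_model.parameters(), max_norm=1., norm_type=2))\n",
    "\n",
    "# Recompute the norm over all gradients, concatenated into one vector\n",
    "all_grads = torch.cat([p.grad.reshape(-1) for p in toy_model.parameters()])\n",
    "norm_after = all_grads.norm(2).item()\n",
    "print(norm_before, norm_after)\n",
    "```\n",
    "\n",
    "In contrast to the two value-based approaches, norm-clipping rescales all gradients by the same factor, so the direction of the overall gradient vector is preserved."
   ]
  },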
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([64, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.01\n",
    "num_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "# Architecture\n",
    "num_features = 784\n",
    "num_hidden_1 = 256\n",
    "num_hidden_2 = 128\n",
    "num_hidden_3 = 64\n",
    "num_hidden_4 = 32\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=batch_size, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=batch_size, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(net, data_loader):\n",
    "    net.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for features, targets in data_loader:\n",
    "            features = features.view(-1, 28*28).to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits, probas = net(features)\n",
    "            _, predicted_labels = torch.max(probas, 1)\n",
    "            num_examples += targets.size(0)\n",
    "            correct_pred += (predicted_labels == targets).sum()\n",
    "        return correct_pred.float()/num_examples * 100\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "class MultilayerPerceptron(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super(MultilayerPerceptron, self).__init__()\n",
    "        \n",
    "        ### 1st hidden layer\n",
    "        self.linear_1 = torch.nn.Linear(num_features, num_hidden_1)\n",
    "\n",
    "        ### 2nd hidden layer\n",
    "        self.linear_2 = torch.nn.Linear(num_hidden_1, num_hidden_2)\n",
    "\n",
    "        ### 3rd hidden layer\n",
    "        self.linear_3 = torch.nn.Linear(num_hidden_2, num_hidden_3)\n",
    "        \n",
    "        ### 4th hidden layer\n",
    "        self.linear_4 = torch.nn.Linear(num_hidden_3, num_hidden_4)\n",
    "        \n",
    "        \n",
    "        ### Output layer\n",
    "        self.linear_out = torch.nn.Linear(num_hidden_4, num_classes)\n",
    "\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.linear_1(x)\n",
    "        out = F.relu(out)\n",
    "        out = self.linear_2(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.linear_3(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.linear_4(out)\n",
    "        out = F.relu(out)\n",
    "        logits = self.linear_out(out)\n",
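    "        # Note: despite the name, these are log-probabilities;\n",
    "        # the argmax used for prediction is the same either way\n",
    "        probas = F.log_softmax(logits, dim=1)\n",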
    "        return logits, probas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 - Basic Clipping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/938 | Cost: 2.3054\n",
      "Epoch: 001/010 | Batch 050/938 | Cost: 0.6427\n",
      "Epoch: 001/010 | Batch 100/938 | Cost: 0.3220\n",
      "Epoch: 001/010 | Batch 150/938 | Cost: 0.3492\n",
      "Epoch: 001/010 | Batch 200/938 | Cost: 0.4505\n",
      "Epoch: 001/010 | Batch 250/938 | Cost: 0.1510\n",
      "Epoch: 001/010 | Batch 300/938 | Cost: 0.2062\n",
      "Epoch: 001/010 | Batch 350/938 | Cost: 0.1287\n",
      "Epoch: 001/010 | Batch 400/938 | Cost: 0.1714\n",
      "Epoch: 001/010 | Batch 450/938 | Cost: 0.3522\n",
      "Epoch: 001/010 | Batch 500/938 | Cost: 0.4268\n",
      "Epoch: 001/010 | Batch 550/938 | Cost: 0.0133\n",
      "Epoch: 001/010 | Batch 600/938 | Cost: 0.1868\n",
      "Epoch: 001/010 | Batch 650/938 | Cost: 0.2312\n",
      "Epoch: 001/010 | Batch 700/938 | Cost: 0.1471\n",
      "Epoch: 001/010 | Batch 750/938 | Cost: 0.1321\n",
      "Epoch: 001/010 | Batch 800/938 | Cost: 0.2776\n",
      "Epoch: 001/010 | Batch 850/938 | Cost: 0.2223\n",
      "Epoch: 001/010 | Batch 900/938 | Cost: 0.1812\n",
      "Epoch: 001/010 training accuracy: 94.72%\n",
      "Time elapsed: 0.25 min\n",
      "Epoch: 002/010 | Batch 000/938 | Cost: 0.2080\n",
      "Epoch: 002/010 | Batch 050/938 | Cost: 0.2177\n",
      "Epoch: 002/010 | Batch 100/938 | Cost: 0.1090\n",
      "Epoch: 002/010 | Batch 150/938 | Cost: 0.1225\n",
      "Epoch: 002/010 | Batch 200/938 | Cost: 0.2514\n",
      "Epoch: 002/010 | Batch 250/938 | Cost: 0.1093\n",
      "Epoch: 002/010 | Batch 300/938 | Cost: 0.0626\n",
      "Epoch: 002/010 | Batch 350/938 | Cost: 0.1242\n",
      "Epoch: 002/010 | Batch 400/938 | Cost: 0.0168\n",
      "Epoch: 002/010 | Batch 450/938 | Cost: 0.2678\n",
      "Epoch: 002/010 | Batch 500/938 | Cost: 0.1761\n",
      "Epoch: 002/010 | Batch 550/938 | Cost: 0.2607\n",
      "Epoch: 002/010 | Batch 600/938 | Cost: 0.1324\n",
      "Epoch: 002/010 | Batch 650/938 | Cost: 0.2334\n",
      "Epoch: 002/010 | Batch 700/938 | Cost: 0.1510\n",
      "Epoch: 002/010 | Batch 750/938 | Cost: 0.1456\n",
      "Epoch: 002/010 | Batch 800/938 | Cost: 0.2882\n",
      "Epoch: 002/010 | Batch 850/938 | Cost: 0.1485\n",
      "Epoch: 002/010 | Batch 900/938 | Cost: 0.2007\n",
      "Epoch: 002/010 training accuracy: 96.83%\n",
      "Time elapsed: 0.49 min\n",
      "Epoch: 003/010 | Batch 000/938 | Cost: 0.0550\n",
      "Epoch: 003/010 | Batch 050/938 | Cost: 0.0555\n",
      "Epoch: 003/010 | Batch 100/938 | Cost: 0.1040\n",
      "Epoch: 003/010 | Batch 150/938 | Cost: 0.2290\n",
      "Epoch: 003/010 | Batch 200/938 | Cost: 0.0506\n",
      "Epoch: 003/010 | Batch 250/938 | Cost: 0.1028\n",
      "Epoch: 003/010 | Batch 300/938 | Cost: 0.0381\n",
      "Epoch: 003/010 | Batch 350/938 | Cost: 0.1593\n",
      "Epoch: 003/010 | Batch 400/938 | Cost: 0.0637\n",
      "Epoch: 003/010 | Batch 450/938 | Cost: 0.0127\n",
      "Epoch: 003/010 | Batch 500/938 | Cost: 0.4391\n",
      "Epoch: 003/010 | Batch 550/938 | Cost: 0.0110\n",
      "Epoch: 003/010 | Batch 600/938 | Cost: 0.1959\n",
      "Epoch: 003/010 | Batch 650/938 | Cost: 0.1020\n",
      "Epoch: 003/010 | Batch 700/938 | Cost: 0.0206\n",
      "Epoch: 003/010 | Batch 750/938 | Cost: 0.2747\n",
      "Epoch: 003/010 | Batch 800/938 | Cost: 0.1192\n",
      "Epoch: 003/010 | Batch 850/938 | Cost: 0.0115\n",
      "Epoch: 003/010 | Batch 900/938 | Cost: 0.2476\n",
      "Epoch: 003/010 training accuracy: 97.65%\n",
      "Time elapsed: 0.74 min\n",
      "Epoch: 004/010 | Batch 000/938 | Cost: 0.0875\n",
      "Epoch: 004/010 | Batch 050/938 | Cost: 0.0335\n",
      "Epoch: 004/010 | Batch 100/938 | Cost: 0.0530\n",
      "Epoch: 004/010 | Batch 150/938 | Cost: 0.4291\n",
      "Epoch: 004/010 | Batch 200/938 | Cost: 0.0634\n",
      "Epoch: 004/010 | Batch 250/938 | Cost: 0.0437\n",
      "Epoch: 004/010 | Batch 300/938 | Cost: 0.0547\n",
      "Epoch: 004/010 | Batch 350/938 | Cost: 0.1602\n",
      "Epoch: 004/010 | Batch 400/938 | Cost: 0.1071\n",
      "Epoch: 004/010 | Batch 450/938 | Cost: 0.0351\n",
      "Epoch: 004/010 | Batch 500/938 | Cost: 0.0712\n",
      "Epoch: 004/010 | Batch 550/938 | Cost: 0.1261\n",
      "Epoch: 004/010 | Batch 600/938 | Cost: 0.1212\n",
      "Epoch: 004/010 | Batch 650/938 | Cost: 0.0802\n",
      "Epoch: 004/010 | Batch 700/938 | Cost: 0.0844\n",
      "Epoch: 004/010 | Batch 750/938 | Cost: 0.1496\n",
      "Epoch: 004/010 | Batch 800/938 | Cost: 0.1543\n",
      "Epoch: 004/010 | Batch 850/938 | Cost: 0.0182\n",
      "Epoch: 004/010 | Batch 900/938 | Cost: 0.0433\n",
      "Epoch: 004/010 training accuracy: 97.08%\n",
      "Time elapsed: 0.98 min\n",
      "Epoch: 005/010 | Batch 000/938 | Cost: 0.1570\n",
      "Epoch: 005/010 | Batch 050/938 | Cost: 0.0291\n",
      "Epoch: 005/010 | Batch 100/938 | Cost: 0.0363\n",
      "Epoch: 005/010 | Batch 150/938 | Cost: 0.0320\n",
      "Epoch: 005/010 | Batch 200/938 | Cost: 0.0322\n",
      "Epoch: 005/010 | Batch 250/938 | Cost: 0.0720\n",
      "Epoch: 005/010 | Batch 300/938 | Cost: 0.0497\n",
      "Epoch: 005/010 | Batch 350/938 | Cost: 0.1058\n",
      "Epoch: 005/010 | Batch 400/938 | Cost: 0.2139\n",
      "Epoch: 005/010 | Batch 450/938 | Cost: 0.0602\n",
      "Epoch: 005/010 | Batch 500/938 | Cost: 0.0689\n",
      "Epoch: 005/010 | Batch 550/938 | Cost: 0.1355\n",
      "Epoch: 005/010 | Batch 600/938 | Cost: 0.1659\n",
      "Epoch: 005/010 | Batch 650/938 | Cost: 0.1504\n",
      "Epoch: 005/010 | Batch 700/938 | Cost: 0.0403\n",
      "Epoch: 005/010 | Batch 750/938 | Cost: 0.3422\n",
      "Epoch: 005/010 | Batch 800/938 | Cost: 0.3299\n",
      "Epoch: 005/010 | Batch 850/938 | Cost: 0.2327\n",
      "Epoch: 005/010 | Batch 900/938 | Cost: 0.0171\n",
      "Epoch: 005/010 training accuracy: 97.51%\n",
      "Time elapsed: 1.23 min\n",
      "Epoch: 006/010 | Batch 000/938 | Cost: 0.0548\n",
      "Epoch: 006/010 | Batch 050/938 | Cost: 0.2781\n",
      "Epoch: 006/010 | Batch 100/938 | Cost: 0.0657\n",
      "Epoch: 006/010 | Batch 150/938 | Cost: 0.0444\n",
      "Epoch: 006/010 | Batch 200/938 | Cost: 0.0057\n",
      "Epoch: 006/010 | Batch 250/938 | Cost: 0.1058\n",
      "Epoch: 006/010 | Batch 300/938 | Cost: 0.1610\n",
      "Epoch: 006/010 | Batch 350/938 | Cost: 0.0353\n",
      "Epoch: 006/010 | Batch 400/938 | Cost: 0.2474\n",
      "Epoch: 006/010 | Batch 450/938 | Cost: 0.1038\n",
      "Epoch: 006/010 | Batch 500/938 | Cost: 0.2918\n",
      "Epoch: 006/010 | Batch 550/938 | Cost: 0.1360\n",
      "Epoch: 006/010 | Batch 600/938 | Cost: 0.1977\n",
      "Epoch: 006/010 | Batch 650/938 | Cost: 0.0314\n",
      "Epoch: 006/010 | Batch 700/938 | Cost: 0.0968\n",
      "Epoch: 006/010 | Batch 750/938 | Cost: 0.2215\n",
      "Epoch: 006/010 | Batch 800/938 | Cost: 0.0328\n",
      "Epoch: 006/010 | Batch 850/938 | Cost: 0.2423\n",
      "Epoch: 006/010 | Batch 900/938 | Cost: 0.1192\n",
      "Epoch: 006/010 training accuracy: 97.47%\n",
      "Time elapsed: 1.48 min\n",
      "Epoch: 007/010 | Batch 000/938 | Cost: 0.0126\n",
      "Epoch: 007/010 | Batch 050/938 | Cost: 0.0735\n",
      "Epoch: 007/010 | Batch 100/938 | Cost: 0.2426\n",
      "Epoch: 007/010 | Batch 150/938 | Cost: 0.0736\n",
      "Epoch: 007/010 | Batch 200/938 | Cost: 0.1387\n",
      "Epoch: 007/010 | Batch 250/938 | Cost: 0.2173\n",
      "Epoch: 007/010 | Batch 300/938 | Cost: 0.0127\n",
      "Epoch: 007/010 | Batch 350/938 | Cost: 0.1131\n",
      "Epoch: 007/010 | Batch 400/938 | Cost: 0.2219\n",
      "Epoch: 007/010 | Batch 450/938 | Cost: 0.0127\n",
      "Epoch: 007/010 | Batch 500/938 | Cost: 0.0905\n",
      "Epoch: 007/010 | Batch 550/938 | Cost: 0.2466\n",
      "Epoch: 007/010 | Batch 600/938 | Cost: 0.0065\n",
      "Epoch: 007/010 | Batch 650/938 | Cost: 0.1477\n",
      "Epoch: 007/010 | Batch 700/938 | Cost: 0.0183\n",
      "Epoch: 007/010 | Batch 750/938 | Cost: 0.0534\n",
      "Epoch: 007/010 | Batch 800/938 | Cost: 0.1139\n",
      "Epoch: 007/010 | Batch 850/938 | Cost: 0.1177\n",
      "Epoch: 007/010 | Batch 900/938 | Cost: 0.0662\n",
      "Epoch: 007/010 training accuracy: 97.74%\n",
      "Time elapsed: 1.72 min\n",
      "Epoch: 008/010 | Batch 000/938 | Cost: 0.0276\n",
      "Epoch: 008/010 | Batch 050/938 | Cost: 0.1275\n",
      "Epoch: 008/010 | Batch 100/938 | Cost: 0.2151\n",
      "Epoch: 008/010 | Batch 150/938 | Cost: 0.0204\n",
      "Epoch: 008/010 | Batch 200/938 | Cost: 0.2154\n",
      "Epoch: 008/010 | Batch 250/938 | Cost: 0.0271\n",
      "Epoch: 008/010 | Batch 300/938 | Cost: 0.0523\n",
      "Epoch: 008/010 | Batch 350/938 | Cost: 0.1604\n",
      "Epoch: 008/010 | Batch 400/938 | Cost: 0.0888\n",
      "Epoch: 008/010 | Batch 450/938 | Cost: 0.0045\n",
      "Epoch: 008/010 | Batch 500/938 | Cost: 0.0288\n",
      "Epoch: 008/010 | Batch 550/938 | Cost: 0.1140\n",
      "Epoch: 008/010 | Batch 600/938 | Cost: 0.0849\n",
      "Epoch: 008/010 | Batch 650/938 | Cost: 0.0216\n",
      "Epoch: 008/010 | Batch 700/938 | Cost: 0.0294\n",
      "Epoch: 008/010 | Batch 750/938 | Cost: 0.0995\n",
      "Epoch: 008/010 | Batch 800/938 | Cost: 0.1159\n",
      "Epoch: 008/010 | Batch 850/938 | Cost: 0.1599\n",
      "Epoch: 008/010 | Batch 900/938 | Cost: 0.1317\n",
      "Epoch: 008/010 training accuracy: 98.29%\n",
      "Time elapsed: 1.97 min\n",
      "Epoch: 009/010 | Batch 000/938 | Cost: 0.1071\n",
      "Epoch: 009/010 | Batch 050/938 | Cost: 0.0580\n",
      "Epoch: 009/010 | Batch 100/938 | Cost: 0.1777\n",
      "Epoch: 009/010 | Batch 150/938 | Cost: 0.2850\n",
      "Epoch: 009/010 | Batch 200/938 | Cost: 0.1229\n",
      "Epoch: 009/010 | Batch 250/938 | Cost: 0.0672\n",
      "Epoch: 009/010 | Batch 300/938 | Cost: 0.2009\n",
      "Epoch: 009/010 | Batch 350/938 | Cost: 0.0110\n",
      "Epoch: 009/010 | Batch 400/938 | Cost: 0.2604\n",
      "Epoch: 009/010 | Batch 450/938 | Cost: 0.0801\n",
      "Epoch: 009/010 | Batch 500/938 | Cost: 0.0092\n",
      "Epoch: 009/010 | Batch 550/938 | Cost: 0.1360\n",
      "Epoch: 009/010 | Batch 600/938 | Cost: 0.0664\n",
      "Epoch: 009/010 | Batch 650/938 | Cost: 0.0886\n",
      "Epoch: 009/010 | Batch 700/938 | Cost: 0.0630\n",
      "Epoch: 009/010 | Batch 750/938 | Cost: 0.0784\n",
      "Epoch: 009/010 | Batch 800/938 | Cost: 0.1736\n",
      "Epoch: 009/010 | Batch 850/938 | Cost: 0.0855\n",
      "Epoch: 009/010 | Batch 900/938 | Cost: 0.2815\n",
      "Epoch: 009/010 training accuracy: 97.74%\n",
      "Time elapsed: 2.21 min\n",
      "Epoch: 010/010 | Batch 000/938 | Cost: 0.0024\n",
      "Epoch: 010/010 | Batch 050/938 | Cost: 0.0497\n",
      "Epoch: 010/010 | Batch 100/938 | Cost: 0.0888\n",
      "Epoch: 010/010 | Batch 150/938 | Cost: 0.1719\n",
      "Epoch: 010/010 | Batch 200/938 | Cost: 0.1729\n",
      "Epoch: 010/010 | Batch 250/938 | Cost: 0.0543\n",
      "Epoch: 010/010 | Batch 300/938 | Cost: 0.3770\n",
      "Epoch: 010/010 | Batch 350/938 | Cost: 0.0270\n",
      "Epoch: 010/010 | Batch 400/938 | Cost: 0.1400\n",
      "Epoch: 010/010 | Batch 450/938 | Cost: 0.0526\n",
      "Epoch: 010/010 | Batch 500/938 | Cost: 0.1984\n",
      "Epoch: 010/010 | Batch 550/938 | Cost: 0.1677\n",
      "Epoch: 010/010 | Batch 600/938 | Cost: 0.0550\n",
      "Epoch: 010/010 | Batch 650/938 | Cost: 0.0294\n",
      "Epoch: 010/010 | Batch 700/938 | Cost: 0.0465\n",
      "Epoch: 010/010 | Batch 750/938 | Cost: 0.1103\n",
      "Epoch: 010/010 | Batch 800/938 | Cost: 0.0272\n",
      "Epoch: 010/010 | Batch 850/938 | Cost: 0.1376\n",
      "Epoch: 010/010 | Batch 900/938 | Cost: 0.0279\n",
      "Epoch: 010/010 training accuracy: 98.09%\n",
      "Time elapsed: 2.46 min\n",
      "Total Training Time: 2.46 min\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(random_seed)\n",
    "model = MultilayerPerceptron(num_features=num_features,\n",
    "                             num_classes=num_classes)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  \n",
    "\n",
    "###################################################################\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.view(-1, 28*28).to(device)\n",
    "        targets = targets.to(device)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        \n",
    "        #########################################################\n",
    "        #########################################################\n",
    "        ### GRADIENT CLIPPING\n",
    "        torch.nn.utils.clip_grad_value_(model.parameters(), 1.)\n",
    "        #########################################################\n",
    "        #########################################################\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader)))\n",
    "        \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 96.80%\n"
     ]
    }
   ],
   "source": [
    "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 - Custom Lower and Upper Bounds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/938 | Cost: 2.3054\n",
      "Epoch: 001/010 | Batch 050/938 | Cost: 0.5977\n",
      "Epoch: 001/010 | Batch 100/938 | Cost: 0.4369\n",
      "Epoch: 001/010 | Batch 150/938 | Cost: 0.3053\n",
      "Epoch: 001/010 | Batch 200/938 | Cost: 0.3661\n",
      "Epoch: 001/010 | Batch 250/938 | Cost: 0.1908\n",
      "Epoch: 001/010 | Batch 300/938 | Cost: 0.2845\n",
      "Epoch: 001/010 | Batch 350/938 | Cost: 0.1928\n",
      "Epoch: 001/010 | Batch 400/938 | Cost: 0.2715\n",
      "Epoch: 001/010 | Batch 450/938 | Cost: 0.2338\n",
      "Epoch: 001/010 | Batch 500/938 | Cost: 0.3923\n",
      "Epoch: 001/010 | Batch 550/938 | Cost: 0.0973\n",
      "Epoch: 001/010 | Batch 600/938 | Cost: 0.3142\n",
      "Epoch: 001/010 | Batch 650/938 | Cost: 0.5024\n",
      "Epoch: 001/010 | Batch 700/938 | Cost: 0.1549\n",
      "Epoch: 001/010 | Batch 750/938 | Cost: 0.1906\n",
      "Epoch: 001/010 | Batch 800/938 | Cost: 0.3325\n",
      "Epoch: 001/010 | Batch 850/938 | Cost: 0.2060\n",
      "Epoch: 001/010 | Batch 900/938 | Cost: 0.1301\n",
      "Epoch: 001/010 training accuracy: 94.76%\n",
      "Time elapsed: 0.24 min\n",
      "Epoch: 002/010 | Batch 000/938 | Cost: 0.2553\n",
      "Epoch: 002/010 | Batch 050/938 | Cost: 0.1858\n",
      "Epoch: 002/010 | Batch 100/938 | Cost: 0.2514\n",
      "Epoch: 002/010 | Batch 150/938 | Cost: 0.1413\n",
      "Epoch: 002/010 | Batch 200/938 | Cost: 0.3071\n",
      "Epoch: 002/010 | Batch 250/938 | Cost: 0.6133\n",
      "Epoch: 002/010 | Batch 300/938 | Cost: 0.1657\n",
      "Epoch: 002/010 | Batch 350/938 | Cost: 0.0828\n",
      "Epoch: 002/010 | Batch 400/938 | Cost: 0.0733\n",
      "Epoch: 002/010 | Batch 450/938 | Cost: 0.3012\n",
      "Epoch: 002/010 | Batch 500/938 | Cost: 0.1857\n",
      "Epoch: 002/010 | Batch 550/938 | Cost: 0.3618\n",
      "Epoch: 002/010 | Batch 600/938 | Cost: 0.0777\n",
      "Epoch: 002/010 | Batch 650/938 | Cost: 0.2648\n",
      "Epoch: 002/010 | Batch 700/938 | Cost: 0.0242\n",
      "Epoch: 002/010 | Batch 750/938 | Cost: 0.1050\n",
      "Epoch: 002/010 | Batch 800/938 | Cost: 0.2148\n",
      "Epoch: 002/010 | Batch 850/938 | Cost: 0.0817\n",
      "Epoch: 002/010 | Batch 900/938 | Cost: 0.1354\n",
      "Epoch: 002/010 training accuracy: 97.04%\n",
      "Time elapsed: 0.49 min\n",
      "Epoch: 003/010 | Batch 000/938 | Cost: 0.1346\n",
      "Epoch: 003/010 | Batch 050/938 | Cost: 0.0825\n",
      "Epoch: 003/010 | Batch 100/938 | Cost: 0.0771\n",
      "Epoch: 003/010 | Batch 150/938 | Cost: 0.2360\n",
      "Epoch: 003/010 | Batch 200/938 | Cost: 0.0730\n",
      "Epoch: 003/010 | Batch 250/938 | Cost: 0.1499\n",
      "Epoch: 003/010 | Batch 300/938 | Cost: 0.0410\n",
      "Epoch: 003/010 | Batch 350/938 | Cost: 0.2091\n",
      "Epoch: 003/010 | Batch 400/938 | Cost: 0.0738\n",
      "Epoch: 003/010 | Batch 450/938 | Cost: 0.0889\n",
      "Epoch: 003/010 | Batch 500/938 | Cost: 0.3630\n",
      "Epoch: 003/010 | Batch 550/938 | Cost: 0.0312\n",
      "Epoch: 003/010 | Batch 600/938 | Cost: 0.0782\n",
      "Epoch: 003/010 | Batch 650/938 | Cost: 0.1753\n",
      "Epoch: 003/010 | Batch 700/938 | Cost: 0.0286\n",
      "Epoch: 003/010 | Batch 750/938 | Cost: 0.2166\n",
      "Epoch: 003/010 | Batch 800/938 | Cost: 0.0627\n",
      "Epoch: 003/010 | Batch 850/938 | Cost: 0.0204\n",
      "Epoch: 003/010 | Batch 900/938 | Cost: 0.2867\n",
      "Epoch: 003/010 training accuracy: 96.72%\n",
      "Time elapsed: 0.73 min\n",
      "Epoch: 004/010 | Batch 000/938 | Cost: 0.0207\n",
      "Epoch: 004/010 | Batch 050/938 | Cost: 0.0499\n",
      "Epoch: 004/010 | Batch 100/938 | Cost: 0.1858\n",
      "Epoch: 004/010 | Batch 150/938 | Cost: 0.2015\n",
      "Epoch: 004/010 | Batch 200/938 | Cost: 0.0285\n",
      "Epoch: 004/010 | Batch 250/938 | Cost: 0.0029\n",
      "Epoch: 004/010 | Batch 300/938 | Cost: 0.1746\n",
      "Epoch: 004/010 | Batch 350/938 | Cost: 0.3149\n",
      "Epoch: 004/010 | Batch 400/938 | Cost: 0.1773\n",
      "Epoch: 004/010 | Batch 450/938 | Cost: 0.1013\n",
      "Epoch: 004/010 | Batch 500/938 | Cost: 0.1665\n",
      "Epoch: 004/010 | Batch 550/938 | Cost: 0.1540\n",
      "Epoch: 004/010 | Batch 600/938 | Cost: 0.1822\n",
      "Epoch: 004/010 | Batch 650/938 | Cost: 0.1506\n",
      "Epoch: 004/010 | Batch 700/938 | Cost: 0.0224\n",
      "Epoch: 004/010 | Batch 750/938 | Cost: 0.1400\n",
      "Epoch: 004/010 | Batch 800/938 | Cost: 0.2262\n",
      "Epoch: 004/010 | Batch 850/938 | Cost: 0.0679\n",
      "Epoch: 004/010 | Batch 900/938 | Cost: 0.0020\n",
      "Epoch: 004/010 training accuracy: 97.63%\n",
      "Time elapsed: 0.98 min\n",
      "Epoch: 005/010 | Batch 000/938 | Cost: 0.0508\n",
      "Epoch: 005/010 | Batch 050/938 | Cost: 0.0585\n",
      "Epoch: 005/010 | Batch 100/938 | Cost: 0.1441\n",
      "Epoch: 005/010 | Batch 150/938 | Cost: 0.0862\n",
      "Epoch: 005/010 | Batch 200/938 | Cost: 0.0284\n",
      "Epoch: 005/010 | Batch 250/938 | Cost: 0.0977\n",
      "Epoch: 005/010 | Batch 300/938 | Cost: 0.0565\n",
      "Epoch: 005/010 | Batch 350/938 | Cost: 0.0272\n",
      "Epoch: 005/010 | Batch 400/938 | Cost: 0.2603\n",
      "Epoch: 005/010 | Batch 450/938 | Cost: 0.1202\n",
      "Epoch: 005/010 | Batch 500/938 | Cost: 0.0612\n",
      "Epoch: 005/010 | Batch 550/938 | Cost: 0.0833\n",
      "Epoch: 005/010 | Batch 600/938 | Cost: 0.1666\n",
      "Epoch: 005/010 | Batch 650/938 | Cost: 0.2642\n",
      "Epoch: 005/010 | Batch 700/938 | Cost: 0.1884\n",
      "Epoch: 005/010 | Batch 750/938 | Cost: 0.1608\n",
      "Epoch: 005/010 | Batch 800/938 | Cost: 0.1029\n",
      "Epoch: 005/010 | Batch 850/938 | Cost: 0.1178\n",
      "Epoch: 005/010 | Batch 900/938 | Cost: 0.0709\n",
      "Epoch: 005/010 training accuracy: 97.58%\n",
      "Time elapsed: 1.23 min\n",
      "Epoch: 006/010 | Batch 000/938 | Cost: 0.0642\n",
      "Epoch: 006/010 | Batch 050/938 | Cost: 0.3518\n",
      "Epoch: 006/010 | Batch 100/938 | Cost: 0.1134\n",
      "Epoch: 006/010 | Batch 150/938 | Cost: 0.0821\n",
      "Epoch: 006/010 | Batch 200/938 | Cost: 0.0645\n",
      "Epoch: 006/010 | Batch 250/938 | Cost: 0.0486\n",
      "Epoch: 006/010 | Batch 300/938 | Cost: 0.0972\n",
      "Epoch: 006/010 | Batch 350/938 | Cost: 0.2861\n",
      "Epoch: 006/010 | Batch 400/938 | Cost: 0.1126\n",
      "Epoch: 006/010 | Batch 450/938 | Cost: 0.1479\n",
      "Epoch: 006/010 | Batch 500/938 | Cost: 0.2181\n",
      "Epoch: 006/010 | Batch 550/938 | Cost: 0.0674\n",
      "Epoch: 006/010 | Batch 600/938 | Cost: 0.0705\n",
      "Epoch: 006/010 | Batch 650/938 | Cost: 0.1032\n",
      "Epoch: 006/010 | Batch 700/938 | Cost: 0.1529\n",
      "Epoch: 006/010 | Batch 750/938 | Cost: 0.2484\n",
      "Epoch: 006/010 | Batch 800/938 | Cost: 0.0432\n",
      "Epoch: 006/010 | Batch 850/938 | Cost: 0.0821\n",
      "Epoch: 006/010 | Batch 900/938 | Cost: 0.1152\n",
      "Epoch: 006/010 training accuracy: 97.09%\n",
      "Time elapsed: 1.47 min\n",
      "Epoch: 007/010 | Batch 000/938 | Cost: 0.0418\n",
      "Epoch: 007/010 | Batch 050/938 | Cost: 0.0527\n",
      "Epoch: 007/010 | Batch 100/938 | Cost: 0.3778\n",
      "Epoch: 007/010 | Batch 150/938 | Cost: 0.1742\n",
      "Epoch: 007/010 | Batch 200/938 | Cost: 0.0725\n",
      "Epoch: 007/010 | Batch 250/938 | Cost: 0.1187\n",
      "Epoch: 007/010 | Batch 300/938 | Cost: 0.0980\n",
      "Epoch: 007/010 | Batch 350/938 | Cost: 0.0077\n",
      "Epoch: 007/010 | Batch 400/938 | Cost: 0.1274\n",
      "Epoch: 007/010 | Batch 450/938 | Cost: 0.1387\n",
      "Epoch: 007/010 | Batch 500/938 | Cost: 0.1959\n",
      "Epoch: 007/010 | Batch 550/938 | Cost: 0.0874\n",
      "Epoch: 007/010 | Batch 600/938 | Cost: 0.2559\n",
      "Epoch: 007/010 | Batch 650/938 | Cost: 0.1413\n",
      "Epoch: 007/010 | Batch 700/938 | Cost: 0.1285\n",
      "Epoch: 007/010 | Batch 750/938 | Cost: 0.1931\n",
      "Epoch: 007/010 | Batch 800/938 | Cost: 0.1151\n",
      "Epoch: 007/010 | Batch 850/938 | Cost: 0.1889\n",
      "Epoch: 007/010 | Batch 900/938 | Cost: 0.5518\n",
      "Epoch: 007/010 training accuracy: 86.62%\n",
      "Time elapsed: 1.72 min\n",
      "Epoch: 008/010 | Batch 000/938 | Cost: 0.3283\n",
      "Epoch: 008/010 | Batch 050/938 | Cost: 0.1818\n",
      "Epoch: 008/010 | Batch 100/938 | Cost: 0.1827\n",
      "Epoch: 008/010 | Batch 150/938 | Cost: 0.0844\n",
      "Epoch: 008/010 | Batch 200/938 | Cost: 0.4017\n",
      "Epoch: 008/010 | Batch 250/938 | Cost: 0.0129\n",
      "Epoch: 008/010 | Batch 300/938 | Cost: 0.0155\n",
      "Epoch: 008/010 | Batch 350/938 | Cost: 0.1844\n",
      "Epoch: 008/010 | Batch 400/938 | Cost: 0.1146\n",
      "Epoch: 008/010 | Batch 450/938 | Cost: 0.0566\n",
      "Epoch: 008/010 | Batch 500/938 | Cost: 0.0895\n",
      "Epoch: 008/010 | Batch 550/938 | Cost: 0.1851\n",
      "Epoch: 008/010 | Batch 600/938 | Cost: 0.1134\n",
      "Epoch: 008/010 | Batch 650/938 | Cost: 0.0838\n",
      "Epoch: 008/010 | Batch 700/938 | Cost: 0.1157\n",
      "Epoch: 008/010 | Batch 750/938 | Cost: 0.2275\n",
      "Epoch: 008/010 | Batch 800/938 | Cost: 0.5753\n",
      "Epoch: 008/010 | Batch 850/938 | Cost: 0.8735\n",
      "Epoch: 008/010 | Batch 900/938 | Cost: 0.7114\n",
      "Epoch: 008/010 training accuracy: 85.51%\n",
      "Time elapsed: 1.97 min\n",
      "Epoch: 009/010 | Batch 000/938 | Cost: 0.4851\n",
      "Epoch: 009/010 | Batch 050/938 | Cost: 0.4595\n",
      "Epoch: 009/010 | Batch 100/938 | Cost: 0.1939\n",
      "Epoch: 009/010 | Batch 150/938 | Cost: 0.1813\n",
      "Epoch: 009/010 | Batch 200/938 | Cost: 0.4969\n",
      "Epoch: 009/010 | Batch 250/938 | Cost: 0.4874\n",
      "Epoch: 009/010 | Batch 300/938 | Cost: 0.1605\n",
      "Epoch: 009/010 | Batch 350/938 | Cost: 0.0899\n",
      "Epoch: 009/010 | Batch 400/938 | Cost: 0.3318\n",
      "Epoch: 009/010 | Batch 450/938 | Cost: 0.0524\n",
      "Epoch: 009/010 | Batch 500/938 | Cost: 0.0215\n",
      "Epoch: 009/010 | Batch 550/938 | Cost: 0.0997\n",
      "Epoch: 009/010 | Batch 600/938 | Cost: 0.0541\n",
      "Epoch: 009/010 | Batch 650/938 | Cost: 0.3480\n",
      "Epoch: 009/010 | Batch 700/938 | Cost: 0.0736\n",
      "Epoch: 009/010 | Batch 750/938 | Cost: 0.1682\n",
      "Epoch: 009/010 | Batch 800/938 | Cost: 0.2877\n",
      "Epoch: 009/010 | Batch 850/938 | Cost: 0.0539\n",
      "Epoch: 009/010 | Batch 900/938 | Cost: 0.2708\n",
      "Epoch: 009/010 training accuracy: 95.67%\n",
      "Time elapsed: 2.21 min\n",
      "Epoch: 010/010 | Batch 000/938 | Cost: 0.0531\n",
      "Epoch: 010/010 | Batch 050/938 | Cost: 0.0453\n",
      "Epoch: 010/010 | Batch 100/938 | Cost: 1.8852\n",
      "Epoch: 010/010 | Batch 150/938 | Cost: 0.1455\n",
      "Epoch: 010/010 | Batch 200/938 | Cost: 0.2089\n",
      "Epoch: 010/010 | Batch 250/938 | Cost: 0.0155\n",
      "Epoch: 010/010 | Batch 300/938 | Cost: 0.9183\n",
      "Epoch: 010/010 | Batch 350/938 | Cost: 0.2231\n",
      "Epoch: 010/010 | Batch 400/938 | Cost: 0.3704\n",
      "Epoch: 010/010 | Batch 450/938 | Cost: 0.1086\n",
      "Epoch: 010/010 | Batch 500/938 | Cost: 0.3775\n",
      "Epoch: 010/010 | Batch 550/938 | Cost: 0.4196\n",
      "Epoch: 010/010 | Batch 600/938 | Cost: 0.2836\n",
      "Epoch: 010/010 | Batch 650/938 | Cost: 0.1170\n",
      "Epoch: 010/010 | Batch 700/938 | Cost: 0.2631\n",
      "Epoch: 010/010 | Batch 750/938 | Cost: 0.1400\n",
      "Epoch: 010/010 | Batch 800/938 | Cost: 0.1048\n",
      "Epoch: 010/010 | Batch 850/938 | Cost: 0.7937\n",
      "Epoch: 010/010 | Batch 900/938 | Cost: 0.2107\n",
      "Epoch: 010/010 training accuracy: 87.98%\n",
      "Time elapsed: 2.46 min\n",
      "Total Training Time: 2.46 min\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(random_seed)\n",
    "model = MultilayerPerceptron(num_features=num_features,\n",
    "                             num_classes=num_classes)\n",
    "\n",
    "#########################################################\n",
    "#########################################################\n",
    "### GRADIENT CLIPPING\n",
    "for p in model.parameters():\n",
    "    p.register_hook(lambda grad: torch.clamp(grad, -0.1, 1.0))\n",
    "#########################################################\n",
    "#########################################################\n",
    "    \n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  \n",
    "\n",
    "###################################################################\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.view(-1, 28*28).to(device)\n",
    "        targets = targets.to(device)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader)))\n",
    "        \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 86.94%\n"
     ]
    }
   ],
   "source": [
    "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
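  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of the clamping hook used in the training cell above more concrete, here is a minimal sketch on a single toy tensor. The tensor `w`, the coefficients `coeffs`, and the resulting gradient values are made up for illustration and are not part of the MNIST experiment; only `register_hook` and `torch.clamp` are taken from the training cell.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Toy parameter; the values are hypothetical and only for illustration.\n",
    "w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)\n",
    "\n",
    "# Same kind of hook as in the training cell: clamp each gradient entry\n",
    "# to the interval [-0.1, 1.0] before it is stored in w.grad.\n",
    "w.register_hook(lambda grad: torch.clamp(grad, -0.1, 1.0))\n",
    "\n",
    "# loss = 5*w[0] - 5*w[1] + 0.5*w[2], so the raw gradient is [5.0, -5.0, 0.5].\n",
    "coeffs = torch.tensor([5.0, -5.0, 0.5])\n",
    "loss = (coeffs * w).sum()\n",
    "loss.backward()\n",
    "\n",
    "# The stored gradient entries are the clamped values 1.0, -0.1 and 0.5.\n",
    "print(w.grad)\n",
    "```\n",
    "\n",
    "Because the bounds are asymmetric, negative gradient entries are cut off much earlier (at -0.1) than positive ones (at 1.0)."
   ]
  },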
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 - Norm-clipping"
   ]
  },
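  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Norm clipping rescales the gradient as a whole instead of clamping individual entries. As a sketch of what `torch.nn.utils.clip_grad_norm_` computes with `norm_type=2` (the symbols $g$, $c$, and $\\epsilon$ are notation introduced here, not names from the PyTorch API): let $g$ be the vector obtained by concatenating the gradients of all parameters, and let $c$ be the `max_norm` argument (`1.` in the cell below). Then\n",
    "\n",
    "$$g \\leftarrow g \\cdot \\min\\left(1, \\frac{c}{\\lVert g \\rVert_2 + \\epsilon}\\right),$$\n",
    "\n",
    "where $\\epsilon$ is a small constant added for numerical stability. Gradients whose joint norm is below $c$ are left essentially unchanged; larger ones are shrunk to a norm of approximately $c$ while keeping their direction."
   ]
  },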
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/938 | Cost: 2.3054\n",
      "Epoch: 001/010 | Batch 050/938 | Cost: 0.5121\n",
      "Epoch: 001/010 | Batch 100/938 | Cost: 0.3424\n",
      "Epoch: 001/010 | Batch 150/938 | Cost: 0.2765\n",
      "Epoch: 001/010 | Batch 200/938 | Cost: 0.5126\n",
      "Epoch: 001/010 | Batch 250/938 | Cost: 0.1481\n",
      "Epoch: 001/010 | Batch 300/938 | Cost: 0.2240\n",
      "Epoch: 001/010 | Batch 350/938 | Cost: 0.1948\n",
      "Epoch: 001/010 | Batch 400/938 | Cost: 0.0655\n",
      "Epoch: 001/010 | Batch 450/938 | Cost: 0.1893\n",
      "Epoch: 001/010 | Batch 500/938 | Cost: 0.4133\n",
      "Epoch: 001/010 | Batch 550/938 | Cost: 0.0375\n",
      "Epoch: 001/010 | Batch 600/938 | Cost: 0.2691\n",
      "Epoch: 001/010 | Batch 650/938 | Cost: 0.3342\n",
      "Epoch: 001/010 | Batch 700/938 | Cost: 0.1662\n",
      "Epoch: 001/010 | Batch 750/938 | Cost: 0.0702\n",
      "Epoch: 001/010 | Batch 800/938 | Cost: 0.4246\n",
      "Epoch: 001/010 | Batch 850/938 | Cost: 0.2282\n",
      "Epoch: 001/010 | Batch 900/938 | Cost: 0.0459\n",
      "Epoch: 001/010 training accuracy: 94.99%\n",
      "Time elapsed: 0.25 min\n",
      "Epoch: 002/010 | Batch 000/938 | Cost: 0.2188\n",
      "Epoch: 002/010 | Batch 050/938 | Cost: 0.3042\n",
      "Epoch: 002/010 | Batch 100/938 | Cost: 0.1391\n",
      "Epoch: 002/010 | Batch 150/938 | Cost: 0.1453\n",
      "Epoch: 002/010 | Batch 200/938 | Cost: 0.3031\n",
      "Epoch: 002/010 | Batch 250/938 | Cost: 0.1398\n",
      "Epoch: 002/010 | Batch 300/938 | Cost: 0.0868\n",
      "Epoch: 002/010 | Batch 350/938 | Cost: 0.1679\n",
      "Epoch: 002/010 | Batch 400/938 | Cost: 0.0480\n",
      "Epoch: 002/010 | Batch 450/938 | Cost: 0.2823\n",
      "Epoch: 002/010 | Batch 500/938 | Cost: 0.2307\n",
      "Epoch: 002/010 | Batch 550/938 | Cost: 0.1610\n",
      "Epoch: 002/010 | Batch 600/938 | Cost: 0.0972\n",
      "Epoch: 002/010 | Batch 650/938 | Cost: 0.3210\n",
      "Epoch: 002/010 | Batch 700/938 | Cost: 0.0697\n",
      "Epoch: 002/010 | Batch 750/938 | Cost: 0.0879\n",
      "Epoch: 002/010 | Batch 800/938 | Cost: 0.2113\n",
      "Epoch: 002/010 | Batch 850/938 | Cost: 0.2496\n",
      "Epoch: 002/010 | Batch 900/938 | Cost: 0.2453\n",
      "Epoch: 002/010 training accuracy: 96.15%\n",
      "Time elapsed: 0.49 min\n",
      "Epoch: 003/010 | Batch 000/938 | Cost: 0.1779\n",
      "Epoch: 003/010 | Batch 050/938 | Cost: 0.0618\n",
      "Epoch: 003/010 | Batch 100/938 | Cost: 0.0570\n",
      "Epoch: 003/010 | Batch 150/938 | Cost: 0.2510\n",
      "Epoch: 003/010 | Batch 200/938 | Cost: 0.1193\n",
      "Epoch: 003/010 | Batch 250/938 | Cost: 0.2530\n",
      "Epoch: 003/010 | Batch 300/938 | Cost: 0.1220\n",
      "Epoch: 003/010 | Batch 350/938 | Cost: 0.2401\n",
      "Epoch: 003/010 | Batch 400/938 | Cost: 0.0520\n",
      "Epoch: 003/010 | Batch 450/938 | Cost: 0.0262\n",
      "Epoch: 003/010 | Batch 500/938 | Cost: 0.2961\n",
      "Epoch: 003/010 | Batch 550/938 | Cost: 0.0030\n",
      "Epoch: 003/010 | Batch 600/938 | Cost: 0.1998\n",
      "Epoch: 003/010 | Batch 650/938 | Cost: 0.1968\n",
      "Epoch: 003/010 | Batch 700/938 | Cost: 0.0499\n",
      "Epoch: 003/010 | Batch 750/938 | Cost: 0.1742\n",
      "Epoch: 003/010 | Batch 800/938 | Cost: 0.1034\n",
      "Epoch: 003/010 | Batch 850/938 | Cost: 0.0437\n",
      "Epoch: 003/010 | Batch 900/938 | Cost: 0.1414\n",
      "Epoch: 003/010 training accuracy: 97.30%\n",
      "Time elapsed: 0.74 min\n",
      "Epoch: 004/010 | Batch 000/938 | Cost: 0.1098\n",
      "Epoch: 004/010 | Batch 050/938 | Cost: 0.0060\n",
      "Epoch: 004/010 | Batch 100/938 | Cost: 0.3551\n",
      "Epoch: 004/010 | Batch 150/938 | Cost: 0.3143\n",
      "Epoch: 004/010 | Batch 200/938 | Cost: 0.0527\n",
      "Epoch: 004/010 | Batch 250/938 | Cost: 0.0204\n",
      "Epoch: 004/010 | Batch 300/938 | Cost: 0.0289\n",
      "Epoch: 004/010 | Batch 350/938 | Cost: 0.2386\n",
      "Epoch: 004/010 | Batch 400/938 | Cost: 0.0694\n",
      "Epoch: 004/010 | Batch 450/938 | Cost: 0.1200\n",
      "Epoch: 004/010 | Batch 500/938 | Cost: 0.0797\n",
      "Epoch: 004/010 | Batch 550/938 | Cost: 0.0891\n",
      "Epoch: 004/010 | Batch 600/938 | Cost: 0.3322\n",
      "Epoch: 004/010 | Batch 650/938 | Cost: 0.1640\n",
      "Epoch: 004/010 | Batch 700/938 | Cost: 0.1170\n",
      "Epoch: 004/010 | Batch 750/938 | Cost: 0.2028\n",
      "Epoch: 004/010 | Batch 800/938 | Cost: 0.2188\n",
      "Epoch: 004/010 | Batch 850/938 | Cost: 0.0575\n",
      "Epoch: 004/010 | Batch 900/938 | Cost: 0.0180\n",
      "Epoch: 004/010 training accuracy: 96.86%\n",
      "Time elapsed: 0.98 min\n",
      "Epoch: 005/010 | Batch 000/938 | Cost: 0.0779\n",
      "Epoch: 005/010 | Batch 050/938 | Cost: 0.1183\n",
      "Epoch: 005/010 | Batch 100/938 | Cost: 0.1184\n",
      "Epoch: 005/010 | Batch 150/938 | Cost: 0.0815\n",
      "Epoch: 005/010 | Batch 200/938 | Cost: 0.0691\n",
      "Epoch: 005/010 | Batch 250/938 | Cost: 0.0784\n",
      "Epoch: 005/010 | Batch 300/938 | Cost: 0.1464\n",
      "Epoch: 005/010 | Batch 350/938 | Cost: 0.1488\n",
      "Epoch: 005/010 | Batch 400/938 | Cost: 0.2636\n",
      "Epoch: 005/010 | Batch 450/938 | Cost: 0.0839\n",
      "Epoch: 005/010 | Batch 500/938 | Cost: 0.1343\n",
      "Epoch: 005/010 | Batch 550/938 | Cost: 0.0514\n",
      "Epoch: 005/010 | Batch 600/938 | Cost: 0.1802\n",
      "Epoch: 005/010 | Batch 650/938 | Cost: 0.0681\n",
      "Epoch: 005/010 | Batch 700/938 | Cost: 0.0986\n",
      "Epoch: 005/010 | Batch 750/938 | Cost: 0.0930\n",
      "Epoch: 005/010 | Batch 800/938 | Cost: 0.1829\n",
      "Epoch: 005/010 | Batch 850/938 | Cost: 0.1694\n",
      "Epoch: 005/010 | Batch 900/938 | Cost: 0.0440\n",
      "Epoch: 005/010 training accuracy: 97.22%\n",
      "Time elapsed: 1.22 min\n",
      "Epoch: 006/010 | Batch 000/938 | Cost: 0.0142\n",
      "Epoch: 006/010 | Batch 050/938 | Cost: 0.3528\n",
      "Epoch: 006/010 | Batch 100/938 | Cost: 0.0710\n",
      "Epoch: 006/010 | Batch 150/938 | Cost: 0.0553\n",
      "Epoch: 006/010 | Batch 200/938 | Cost: 0.0084\n",
      "Epoch: 006/010 | Batch 250/938 | Cost: 0.1178\n",
      "Epoch: 006/010 | Batch 300/938 | Cost: 0.1271\n",
      "Epoch: 006/010 | Batch 350/938 | Cost: 0.0404\n",
      "Epoch: 006/010 | Batch 400/938 | Cost: 0.1435\n",
      "Epoch: 006/010 | Batch 450/938 | Cost: 0.1568\n",
      "Epoch: 006/010 | Batch 500/938 | Cost: 0.2100\n",
      "Epoch: 006/010 | Batch 550/938 | Cost: 0.0019\n",
      "Epoch: 006/010 | Batch 600/938 | Cost: 0.1721\n",
      "Epoch: 006/010 | Batch 650/938 | Cost: 0.0943\n",
      "Epoch: 006/010 | Batch 700/938 | Cost: 0.0913\n",
      "Epoch: 006/010 | Batch 750/938 | Cost: 0.1211\n",
      "Epoch: 006/010 | Batch 800/938 | Cost: 0.0890\n",
      "Epoch: 006/010 | Batch 850/938 | Cost: 0.0390\n",
      "Epoch: 006/010 | Batch 900/938 | Cost: 0.0521\n",
      "Epoch: 006/010 training accuracy: 97.79%\n",
      "Time elapsed: 1.47 min\n",
      "Epoch: 007/010 | Batch 000/938 | Cost: 0.0059\n",
      "Epoch: 007/010 | Batch 050/938 | Cost: 0.0371\n",
      "Epoch: 007/010 | Batch 100/938 | Cost: 0.2702\n",
      "Epoch: 007/010 | Batch 150/938 | Cost: 0.1142\n",
      "Epoch: 007/010 | Batch 200/938 | Cost: 0.0900\n",
      "Epoch: 007/010 | Batch 250/938 | Cost: 0.1922\n",
      "Epoch: 007/010 | Batch 300/938 | Cost: 0.0062\n",
      "Epoch: 007/010 | Batch 350/938 | Cost: 0.0435\n",
      "Epoch: 007/010 | Batch 400/938 | Cost: 0.0503\n",
      "Epoch: 007/010 | Batch 450/938 | Cost: 0.1411\n",
      "Epoch: 007/010 | Batch 500/938 | Cost: 0.1547\n",
      "Epoch: 007/010 | Batch 550/938 | Cost: 0.1858\n",
      "Epoch: 007/010 | Batch 600/938 | Cost: 0.0108\n",
      "Epoch: 007/010 | Batch 650/938 | Cost: 0.0569\n",
      "Epoch: 007/010 | Batch 700/938 | Cost: 0.0254\n",
      "Epoch: 007/010 | Batch 750/938 | Cost: 0.0635\n",
      "Epoch: 007/010 | Batch 800/938 | Cost: 0.2539\n",
      "Epoch: 007/010 | Batch 850/938 | Cost: 0.1338\n",
      "Epoch: 007/010 | Batch 900/938 | Cost: 0.3336\n",
      "Epoch: 007/010 training accuracy: 98.25%\n",
      "Time elapsed: 1.71 min\n",
      "Epoch: 008/010 | Batch 000/938 | Cost: 0.0215\n",
      "Epoch: 008/010 | Batch 050/938 | Cost: 0.2800\n",
      "Epoch: 008/010 | Batch 100/938 | Cost: 0.2627\n",
      "Epoch: 008/010 | Batch 150/938 | Cost: 0.0538\n",
      "Epoch: 008/010 | Batch 200/938 | Cost: 0.2164\n",
      "Epoch: 008/010 | Batch 250/938 | Cost: 0.0025\n",
      "Epoch: 008/010 | Batch 300/938 | Cost: 0.0021\n",
      "Epoch: 008/010 | Batch 350/938 | Cost: 0.1489\n",
      "Epoch: 008/010 | Batch 400/938 | Cost: 0.0997\n",
      "Epoch: 008/010 | Batch 450/938 | Cost: 0.0055\n",
      "Epoch: 008/010 | Batch 500/938 | Cost: 0.0181\n",
      "Epoch: 008/010 | Batch 550/938 | Cost: 0.1672\n",
      "Epoch: 008/010 | Batch 600/938 | Cost: 0.0538\n",
      "Epoch: 008/010 | Batch 650/938 | Cost: 0.0842\n",
      "Epoch: 008/010 | Batch 700/938 | Cost: 0.0941\n",
      "Epoch: 008/010 | Batch 750/938 | Cost: 0.0171\n",
      "Epoch: 008/010 | Batch 800/938 | Cost: 0.0638\n",
      "Epoch: 008/010 | Batch 850/938 | Cost: 0.2507\n",
      "Epoch: 008/010 | Batch 900/938 | Cost: 0.0568\n",
      "Epoch: 008/010 training accuracy: 98.31%\n",
      "Time elapsed: 1.96 min\n",
      "Epoch: 009/010 | Batch 000/938 | Cost: 0.0844\n",
      "Epoch: 009/010 | Batch 050/938 | Cost: 0.1087\n",
      "Epoch: 009/010 | Batch 100/938 | Cost: 0.0584\n",
      "Epoch: 009/010 | Batch 150/938 | Cost: 0.0544\n",
      "Epoch: 009/010 | Batch 200/938 | Cost: 0.0352\n",
      "Epoch: 009/010 | Batch 250/938 | Cost: 0.0189\n",
      "Epoch: 009/010 | Batch 300/938 | Cost: 0.0356\n",
      "Epoch: 009/010 | Batch 350/938 | Cost: 0.1357\n",
      "Epoch: 009/010 | Batch 400/938 | Cost: 0.2133\n",
      "Epoch: 009/010 | Batch 450/938 | Cost: 0.0081\n",
      "Epoch: 009/010 | Batch 500/938 | Cost: 0.0710\n",
      "Epoch: 009/010 | Batch 550/938 | Cost: 0.0652\n",
      "Epoch: 009/010 | Batch 600/938 | Cost: 0.0136\n",
      "Epoch: 009/010 | Batch 650/938 | Cost: 0.0772\n",
      "Epoch: 009/010 | Batch 700/938 | Cost: 0.0744\n",
      "Epoch: 009/010 | Batch 750/938 | Cost: 0.0388\n",
      "Epoch: 009/010 | Batch 800/938 | Cost: 0.0208\n",
      "Epoch: 009/010 | Batch 850/938 | Cost: 0.0114\n",
      "Epoch: 009/010 | Batch 900/938 | Cost: 0.0706\n",
      "Epoch: 009/010 training accuracy: 97.76%\n",
      "Time elapsed: 2.20 min\n",
      "Epoch: 010/010 | Batch 000/938 | Cost: 0.0773\n",
      "Epoch: 010/010 | Batch 050/938 | Cost: 0.0362\n",
      "Epoch: 010/010 | Batch 100/938 | Cost: 0.0406\n",
      "Epoch: 010/010 | Batch 150/938 | Cost: 0.0900\n",
      "Epoch: 010/010 | Batch 200/938 | Cost: 0.3629\n",
      "Epoch: 010/010 | Batch 250/938 | Cost: 0.0016\n",
      "Epoch: 010/010 | Batch 300/938 | Cost: 0.0314\n",
      "Epoch: 010/010 | Batch 350/938 | Cost: 0.0677\n",
      "Epoch: 010/010 | Batch 400/938 | Cost: 0.0821\n",
      "Epoch: 010/010 | Batch 450/938 | Cost: 0.0717\n",
      "Epoch: 010/010 | Batch 500/938 | Cost: 0.2704\n",
      "Epoch: 010/010 | Batch 550/938 | Cost: 0.1784\n",
      "Epoch: 010/010 | Batch 600/938 | Cost: 0.0899\n",
      "Epoch: 010/010 | Batch 650/938 | Cost: 0.0578\n",
      "Epoch: 010/010 | Batch 700/938 | Cost: 0.1572\n",
      "Epoch: 010/010 | Batch 750/938 | Cost: 0.0106\n",
      "Epoch: 010/010 | Batch 800/938 | Cost: 0.0714\n",
      "Epoch: 010/010 | Batch 850/938 | Cost: 0.0125\n",
      "Epoch: 010/010 | Batch 900/938 | Cost: 0.0235\n",
      "Epoch: 010/010 training accuracy: 98.38%\n",
      "Time elapsed: 2.45 min\n",
      "Total Training Time: 2.45 min\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(random_seed)\n",
    "model = MultilayerPerceptron(num_features=num_features,\n",
    "                             num_classes=num_classes)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  \n",
    "\n",
    "###################################################################\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.view(-1, 28*28).to(device)\n",
    "        targets = targets.to(device)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        \n",
    "        #########################################################\n",
    "        #########################################################\n",
    "        ### GRADIENT CLIPPING\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1., norm_type=2)\n",
    "        #########################################################\n",
    "        #########################################################\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader)))\n",
    "        \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 96.89%\n"
     ]
    }
   ],
   "source": [
    "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
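  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small sanity check of the norm clipping call used above, the following sketch applies `torch.nn.utils.clip_grad_norm_` to two toy parameters with hand-set gradients. The parameters `p1` and `p2` and their gradient values are hypothetical and chosen so that the joint norm is easy to verify by hand; they are not part of the MNIST experiment.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Two toy parameters with hand-set gradients (made-up values).\n",
    "p1 = torch.nn.Parameter(torch.zeros(2))\n",
    "p2 = torch.nn.Parameter(torch.zeros(1))\n",
    "p1.grad = torch.tensor([3.0, 0.0])\n",
    "p2.grad = torch.tensor([4.0])\n",
    "\n",
    "# The L2 norm is taken over all gradients together:\n",
    "# sqrt(3^2 + 0^2 + 4^2) = 5. The function returns this pre-clipping norm.\n",
    "total_norm = torch.nn.utils.clip_grad_norm_([p1, p2], 1., norm_type=2)\n",
    "print(float(total_norm))\n",
    "\n",
    "# All gradients were scaled in place by the same factor (about 1/5),\n",
    "# so the direction is preserved and the joint norm is now about 1.\n",
    "clipped_norm = torch.sqrt(p1.grad.pow(2).sum() + p2.grad.pow(2).sum())\n",
    "print(p1.grad, p2.grad, float(clipped_norm))\n",
    "```\n",
    "\n",
    "Unlike the element-wise clamping in sections 1 and 2, the ratio between gradient entries stays the same here; only the overall step length is limited."
   ]
  },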
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.16.4\n",
      "torch       1.2.0\n",
      "torchvision 0.4.0a0+6b959ee\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
